# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global variables and functions for TF/Keras compatibility."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import weakref

import tensorflow as tf


def _get_keras_instance():
  """Returns the Keras 2 module to use throughout this package.

  Sets the `TF_USE_LEGACY_KERAS` environment variable to '1', then inspects
  `tf.keras.version()` when that attribute exists. If the reported version
  starts with '3.', the standalone `tf_keras` package is imported and returned;
  in every other case `tf.keras` itself is returned.

  Returns:
    Either the `tf_keras` module or `tf.keras`.
  """
  # Keep using keras-2 (tf-keras) rather than keras-3 (keras).
  os.environ['TF_USE_LEGACY_KERAS'] = '1'

  get_version = getattr(tf.keras, 'version', None)
  bundled_is_keras_3 = bool(get_version) and get_version().startswith('3.')
  if not bundled_is_keras_3:
    return tf.keras

  # tf.keras is Keras 3 here, so fall back to the Keras 2 package.
  import tf_keras  # pylint: disable=g-import-not-at-top,unused-import
  return tf_keras


# Module-level Keras handle, resolved once at import time: `tf_keras` when
# `tf.keras` reports a 3.x version, otherwise `tf.keras` itself.
keras = _get_keras_instance()

def assign(ref, value, name=None):
  """Assigns `value` to `ref` using whichever assign API is available.

  Args:
    ref: Object to assign to. When `tf.assign` is not available it must expose
      an `assign(value, name=...)` method.
    value: The value to assign.
    name: Optional name forwarded to the underlying assign call.

  Returns:
    Whatever the underlying assign call returns.
  """
  if not hasattr(tf, 'assign'):
    # No module-level `tf.assign`: use the method on the object itself.
    return ref.assign(value, name=name)
  return tf.assign(ref, value, name=name)


def initialize_variables(testcase):
  """Handle global variable initialization in TF 1.X.

  Does nothing when `tf.global_variables_initializer` is absent or when
  TensorFlow is executing eagerly; otherwise evaluates the global variables
  initializer through the test case.

  Args:
    testcase: instance of tf.test.TestCase
  """
  if not hasattr(tf, 'global_variables_initializer'):
    return
  if tf.executing_eagerly():
    return
  init_op = tf.global_variables_initializer()
  testcase.evaluate(init_op)


def is_v1_apis():
  """Returns True if the imported `tf` module exposes `tf.assign`.

  The presence of the module-level `assign` symbol is used as the marker for
  the V1 API surface.
  """
  has_v1_assign = hasattr(tf, 'assign')
  return has_v1_assign


# A global dictionary mapping each graph object to its own table of counters
# used for autogenerated layer/optimizer names in that graph.
# Keeping a separate table per graph makes autogenerated names unique in a
# graph-specific way. Keys are held weakly, so an entry here does not keep its
# graph alive.
PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()


def get_default_graph_uid_map():
  """Returns the name-counter table associated with the default graph.

  Looks up the current default graph in `PER_GRAPH_OBJECT_NAME_UIDS` and
  creates (and registers) an empty `defaultdict(int)` the first time a graph
  is seen.

  Returns:
    The `defaultdict(int)` of name counters for the default graph.
  """
  current_graph = tf.compat.v1.get_default_graph()
  uid_map = PER_GRAPH_OBJECT_NAME_UIDS.get(current_graph)
  if uid_map is not None:
    return uid_map

  # First request for this graph: register a fresh counter table.
  uid_map = collections.defaultdict(int)
  PER_GRAPH_OBJECT_NAME_UIDS[current_graph] = uid_map
  return uid_map


# Names that `unique_object_name` must not hand out when it is called with
# `avoid_observed_names=True`. Add a name to this set to reserve it.
OBSERVED_NAMES = set()


def unique_object_name(
    name,
    name_uid_map=None,
    avoid_names=None,
    namespace='',
    zero_based=False,
    avoid_observed_names=False,
):
  """Makes an object name (or any string) unique within a TF-Keras session.

  Args:
    name: String name to make unique.
    name_uid_map: An optional defaultdict(int) to use when creating unique
      names. If None (default), uses a per-Graph dictionary. The counters in
      this mapping are advanced by every call.
    avoid_names: An optional set or dict with names which should not be used. If
      None (default), don't avoid any names unless `avoid_observed_names` is
      True.
    namespace: Gets a name which is unique within the (graph, namespace). Layers
      which are not Networks use a blank namespace and so get graph-global
      names.
    zero_based: If True, name sequences start with no suffix (e.g. "dense",
      "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
    avoid_observed_names: If True and `avoid_names` is None, avoid every name
      currently stored in the module-level `OBSERVED_NAMES` set. Ignored when
      `avoid_names` is given explicitly.

  Returns:
    Unique string name.

  Example:

    unique_object_name('dense')  # dense_1
    unique_object_name('dense')  # dense_2
  """
  if name_uid_map is None:
    name_uid_map = get_default_graph_uid_map()
  if avoid_names is None:
    avoid_names = OBSERVED_NAMES if avoid_observed_names else set()

  # The counter key does not change between attempts, so build it once.
  name_key = (namespace, name)
  proposed_name = None
  # Keep advancing the counter until the candidate is not a name to avoid.
  while proposed_name is None or proposed_name in avoid_names:
    if zero_based:
      # Zero-based: the first candidate is the bare name, then name_1, name_2...
      number = name_uid_map[name_key]
      if number:
        proposed_name = name + '_' + str(number)
      else:
        proposed_name = name
      name_uid_map[name_key] += 1
    else:
      # One-based: bump first so the first candidate is name_1.
      name_uid_map[name_key] += 1
      proposed_name = name + '_' + str(name_uid_map[name_key])
  return proposed_name
